{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C84HP8qjcBo_"
      },
      "source": [
        "## T5-Plex Demo\n",
        "\n",
        "*Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "\\\n",
        "\n",
        "\u003ca href=\"https://colab.research.google.com/github/google/uncertainty-baselines/blob/main/experimental/plex/plex_t5_demo.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
        "\n",
        "To run this public colab, please use its `Connect to a local runtime` option by following the **Setup Guide** below.\n",
        "\n",
        "\\\n",
        "\\\n",
        "This notebook demonstrates how to load the released **T5-Plex** checkpoints from the paper *Plex: Towards Reliability using Pretrained Large Model Extensions* with [JAX](https://jax.readthedocs.io/) and run inference on a single example.\n",
        "\n",
        "For more advanced usage, the full training and fine-tuning scripts can be found at https://github.com/google/uncertainty-baselines/tree/main/baselines/t5."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HDfn_-8_uErR"
      },
      "source": [
        "## Setup Guide for Colab Local Runtime\n",
        "(This setup guide is adapted from the [T5X](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md) public colab.)\n",
        "\n",
        "The [default public Colab](https://colab.research.google.com/) runtime currently doesn't support Python versions higher than 3.7, but running Plex-T5 models requires a newer version. As an alternative, create a custom Jupyter kernel/runtime on a local machine (or a cloud machine), then use Colab's `Connect to a local runtime` option to run this notebook. [This page](https://research.google.com/colaboratory/local-runtimes.html) contains additional details on how to set up and use a local runtime.\n",
        "\n",
        "\n",
        "### Prepare the Python environment\n",
        "\n",
        "On a local machine (or a cloud machine), create a Python environment via\n",
        "\n",
        "```\n",
        "sudo apt update\n",
        "sudo apt install python3-pip\n",
        "\n",
        "sudo apt install -y python3.10 python3.10-venv\n",
        "python3.10 -m venv plex_t5_venv\n",
        "```\n",
        "\n",
        "### Install T5X with its dependencies\n",
        "\n",
        "```\n",
        "source plex_t5_venv/bin/activate\n",
        "python3 -m pip install -U pip setuptools wheel ipython\n",
        "python3 -m pip install flax\n",
        "git clone --branch=main https://github.com/google-research/t5x\n",
        "cd t5x\n",
        "python3 -m pip install -e .\n",
        "cd -\n",
        "```\n",
        "\n",
        "### Install `uncertainty-baselines` with its dependencies\n",
        "\n",
        "```\n",
        "rm -rf uncertainty-baselines\n",
        "git clone https://github.com/google/uncertainty-baselines.git\n",
        "cp -r uncertainty-baselines/baselines/t5/* .\n",
        "\n",
        "python3 -m pip install ./uncertainty-baselines[models,datasets]\n",
        "```\n",
        "\n",
        "Installing the packages above downgrades `jax` and `jaxlib`, so reinstall the newer versions:\n",
        "\n",
        "```\n",
        "python3 -m pip install jax==0.3.23 jaxlib==0.3.22\n",
        "```\n",
        "\n",
        "Note that you might get warnings about `tensorflow-federated` having an incompatible version. These warnings do not affect the usage of this colab.\n",
        "\n",
        "### Install and launch Jupyter\n",
        "Finally, prepare a local Jupyter runtime that the colab notebook can connect to. For the detailed official instructions, see [here](https://research.google.com/colaboratory/local-runtimes.html).\n",
        "\n",
        "```\n",
        "python3 -m pip install notebook\n",
        "python3 -m pip install --upgrade 'jupyter_http_over_ws\u003e=0.0.7'\n",
        "jupyter serverextension enable --py jupyter_http_over_ws\n",
        "```\n",
        "\n",
        "Use the command below to **launch** the prepared runtime.\n",
        "\n",
        "```\n",
        "jupyter notebook --NotebookApp.allow_origin='https://colab.research.google.com' --port=8888 --NotebookApp.port_retries=0\n",
        "```\n",
        "\n",
        "Note that, depending on the installation path of `jupyter`, you might need to replace `jupyter` with `~/plex_t5_venv/bin/jupyter` in the above commands. To find the installation path of `jupyter` in the virtual environment, run `which jupyter`.\n",
        "\n",
        "\\\n",
        "\n",
        "You could also swap `allow_origin='https://colab.research.google.com'` with `allow_origin='https://colab.sandbox.google.com'` if needed.\n",
        "\n",
        "\\\n",
        "\n",
        "The log of the above command contains a URL starting with `http://localhost:8888/?token=`. Copy and paste it into the `Connect to a local runtime` option, and you should be able to run this colab.\n"
      ]
    },
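    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before connecting the notebook, it can help to confirm that the virtual environment matches what the setup guide above installs. The snippet below is a minimal sketch, not part of the released demo: `installed_version` and `check_environment` are hypothetical helper names, the minimum Python version of 3.8 is an assumption derived from the note that a version higher than 3.7 is needed, and the distribution names passed to `importlib.metadata` are assumed to match what `pip` installed.\n",
        "\n",
        "```python\n",
        "import sys\n",
        "from importlib import metadata\n",
        "\n",
        "\n",
        "def installed_version(distribution):\n",
        "    '''Return the installed version of a distribution, or None if it is missing.'''\n",
        "    try:\n",
        "        return metadata.version(distribution)\n",
        "    except metadata.PackageNotFoundError:\n",
        "        return None\n",
        "\n",
        "\n",
        "def check_environment(min_python=(3, 8),\n",
        "                      distributions=('jax', 'jaxlib', 'flax', 't5x')):\n",
        "    '''Summarize the Python version and the versions of key packages.'''\n",
        "    report = {\n",
        "        'python': '.'.join(str(part) for part in sys.version_info[:3]),\n",
        "        # Assumption: anything newer than Python 3.7 is acceptable.\n",
        "        'python_ok': sys.version_info[:2] >= min_python,\n",
        "    }\n",
        "    for name in distributions:\n",
        "        report[name] = installed_version(name)\n",
        "    return report\n",
        "\n",
        "\n",
        "for key, value in check_environment().items():\n",
        "    print(f'{key}: {value}')\n",
        "```\n",
        "\n",
        "Run it with the `python3` from `plex_t5_venv`; a value of `None` means the distribution was not found in that environment."
      ]
    },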
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lu2-L9KLM1lu"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F2WBGHCY8HJd"
      },
      "outputs": [],
      "source": [
        "import functools\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "from jax import random\n",
        "\n",
        "import seqio\n",
        "import t5x\n",
        "from t5x import utils as t5_utils\n",
        "from t5x import partitioning\n",
        "from t5x.examples.t5 import network\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()\n",
        "\n",
        "import decoding\n",
        "import utils\n",
        "import uncertainty_baselines as ub\n",
        "from data.tasks import nalue as nalue_task\n",
        "from uncertainty_baselines.models import t5_be_gp\n",
        "from models import be_models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VBaa2-R4M31e"
      },
      "source": [
        "## Define model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QQ76r3hpUsX0"
      },
      "outputs": [],
      "source": [
        "# Define transformer module.\n",
        "DROPOUT_RATE = 0.0\n",
        "NUM_EMBEDDINGS = 32128\n",
        "COVMAT_MOMENTUM = -1.0\n",
        "BE_ENS_SIZE = 5\n",
        "MEAN_FIELD_FACTOR = 0.0001\n",
        "NORMALIZE_INPUT = True\n",
        "RANDOM_SIGN = 0.5\n",
        "STEPS_PER_EPOCH = None\n",
        "\n",
        "t5_config = network.T5Config(\n",
        "    vocab_size=NUM_EMBEDDINGS,\n",
        "    dropout_rate=DROPOUT_RATE,\n",
        "    dtype='bfloat16',\n",
        "    emb_dim=1024,\n",
        "    head_dim=64,\n",
        "    logits_via_embedding=False,\n",
        "    mlp_activations=('gelu', 'linear'),\n",
        "    mlp_dim=2816,\n",
        "    num_decoder_layers=24,\n",
        "    num_encoder_layers=24,\n",
        "    num_heads=16)\n",
        "module = t5_be_gp.TransformerBEGp(\n",
        "    config=t5_config,\n",
        "    be_decoder_layers=(-1,),\n",
        "    covmat_momentum=COVMAT_MOMENTUM,\n",
        "    ens_size=BE_ENS_SIZE,\n",
        "    mean_field_factor=MEAN_FIELD_FACTOR,\n",
        "    normalize_input=NORMALIZE_INPUT,\n",
        "    random_sign_init=RANDOM_SIGN,\n",
        "    ridge_penalty=1.0,\n",
        "    steps_per_epoch=STEPS_PER_EPOCH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xg2dwtppU64U"
      },
      "outputs": [],
      "source": [
        "# Define vocab.\n",
        "sentencepiece_model_file = \"gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model\"\n",
        "VOCABULARY = seqio.SentencePieceVocabulary(sentencepiece_model_file=sentencepiece_model_file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6skHLZg4VSmD"
      },
      "outputs": [],
      "source": [
        "# Define optimizer.\n",
        "OPTIMIZER = utils.AdafactorGP(decay_rate=0.8, step_offset=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QAHtGrxjVZtK"
      },
      "outputs": [],
      "source": [
        "# Define loss hyperparameters.\n",
        "Z_LOSS = 0.0001\n",
        "LABEL_SMOOTHING = 0.0\n",
        "LOSS_NORMALIZING_FACTOR = 233472\n",
        "LABEL_TOKENS = nalue_task.get_nalue_intent_tokens()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ofsT5AxsVb-q"
      },
      "outputs": [],
      "source": [
        "model = be_models.EncoderDecoderBEGpClassifierModel(\n",
        "    module=module,\n",
        "    decode_fn=functools.partial(\n",
        "        decoding.beam_search, alpha=0., return_token_scores=True),\n",
        "    input_vocabulary=VOCABULARY,\n",
        "    label_smoothing=LABEL_SMOOTHING,\n",
        "    label_tokens=LABEL_TOKENS,\n",
        "    loss_normalizing_factor=LOSS_NORMALIZING_FACTOR,\n",
        "    optimizer_def=OPTIMIZER,\n",
        "    output_vocabulary=VOCABULARY,\n",
        "    z_loss=Z_LOSS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hWl7lKLEM8WU"
      },
      "source": [
        "## Load checkpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x_bQODlUV7Jr"
      },
      "outputs": [],
      "source": [
        "CHECKPOINT_PATH = 'gs://plex-paper/plex_t5_large_c4_to_nalue/'\n",
        "restore_checkpoint_cfg = t5_utils.RestoreCheckpointConfig(\n",
        "    path=CHECKPOINT_PATH, mode='specific', use_gda=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MqStJGauUdNo"
      },
      "outputs": [],
      "source": [
        "partitioner = partitioning.PjitPartitioner(\n",
        "    num_partitions = 1,\n",
        "    logical_axis_rules = partitioning.standard_logical_axis_rules()\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p4CceoB2XtRf"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 3}\n",
        "input_shapes = {\n",
        "    'decoder_input_tokens': (batch_size, TASK_FEATURE_LENGTHS['targets']),\n",
        "    'decoder_loss_weights': (batch_size, TASK_FEATURE_LENGTHS['targets']),\n",
        "    'decoder_target_tokens': (batch_size, TASK_FEATURE_LENGTHS['targets']),\n",
        "    'encoder_input_tokens': (batch_size, TASK_FEATURE_LENGTHS['inputs'])\n",
        "}\n",
        "\n",
        "# Create train state initializer.\n",
        "train_state_initializer = t5_utils.TrainStateInitializer(\n",
        "    optimizer_def=None,  # Do not load optimizer state.\n",
        "    init_fn=model.get_initial_variables,\n",
        "    input_shapes=input_shapes,\n",
        "    partitioner=partitioner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OdHmN6ytytSJ"
      },
      "outputs": [],
      "source": [
        "# Restore train state from checkpoint.\n",
        "train_state = train_state_initializer.from_checkpoint([restore_checkpoint_cfg])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ty8MscAzNBaz"
      },
      "source": [
        "## Run inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lziDAHrpa-mF"
      },
      "outputs": [],
      "source": [
        "input_text = \"can you please provide me with assistance in moving money from one account to another\" #@param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eeSzln7Aa_4K"
      },
      "outputs": [],
      "source": [
        "infer_step_jit = jax.jit(model.predict_batch_with_aux)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4p8v31PKbDzg"
      },
      "outputs": [],
      "source": [
        "# Tokenize the input and right-pad it with zeros to the encoder input length.\n",
        "max_input_length = TASK_FEATURE_LENGTHS['inputs']\n",
        "input_tokenized = VOCABULARY.encode(input_text)\n",
        "input_padded = np.pad(input_tokenized,\n",
        "                      (0, max_input_length - len(input_tokenized)))\n",
        "infer_batch = {}\n",
        "# Add a leading batch dimension of size 1.\n",
        "infer_batch['encoder_input_tokens'] = jnp.expand_dims(input_padded, axis=0)\n",
        "infer_batch['decoder_input_tokens'] = np.zeros(\n",
        "    (1, TASK_FEATURE_LENGTHS['targets']), dtype=np.int32)"
      ]
    },
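    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The padding step in the cell above can be sketched in isolation. This is a minimal NumPy illustration, not the notebook's own code: `pad_and_batch` is a hypothetical helper, the token ids are made up rather than produced by the SentencePiece vocabulary, and truncating over-long inputs is an added assumption (the cell above would raise an error instead). It also assumes that id 0 is the padding id, as the zero padding above implies.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def pad_and_batch(token_ids, max_length, pad_id=0):\n",
        "    '''Right-pad (or truncate) token ids to max_length and add a batch axis.'''\n",
        "    token_ids = np.asarray(token_ids, dtype=np.int32)[:max_length]\n",
        "    padded = np.pad(token_ids, (0, max_length - len(token_ids)),\n",
        "                    constant_values=pad_id)\n",
        "    # Shape (max_length,) becomes (1, max_length).\n",
        "    return padded[np.newaxis, :]\n",
        "\n",
        "\n",
        "example_ids = [71, 1712, 19, 3, 9]  # Hypothetical ids, not real tokenizer output.\n",
        "batch = pad_and_batch(example_ids, max_length=8)\n",
        "print(batch.shape)  # (1, 8)\n",
        "print(batch)\n",
        "```\n",
        "\n",
        "In the notebook, `max_length` corresponds to `TASK_FEATURE_LENGTHS['inputs']`, so the encoder always receives an array of shape `(1, 512)` regardless of the length of `input_text`."
      ]
    },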
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S7dHDdZXbW8u"
      },
      "outputs": [],
      "source": [
        "# Run inference on the batch with the jitted inference step.\n",
        "rng = jax.random.PRNGKey(0)\n",
        "batch_result = infer_step_jit(train_state.params, infer_batch, rng)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k9DnxfMnbb9G"
      },
      "outputs": [],
      "source": [
        "VOCABULARY.decode_tf(batch_result[0])"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "provenance": [
        {
          "file_id": "1jgAFJyHowVuzr2yNm6zBy183XxbEbyE1",
          "timestamp": 1665078573501
        },
        {
          "file_id": "1amsX-i0K0R4TRerm7TtOwjwuTsF2ZtKx",
          "timestamp": 1665006819770
        }
      ],
      "toc_visible": true
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
